[NO MERGE] Integrate CK varlen cross attention for small-seq (s_q=1, s_kv<=16)#461
VeeraRajasekhar wants to merge 9 commits into dev
Conversation
Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention with Q length 1, KV length 2-16, and large batch.
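The eligibility condition this PR introduces (THD layout, no bias, Q length 1, KV length in the 2-16 window) can be sketched as a small predicate. This is illustrative Python, not the PR's actual C++ dispatch code; the function name and string-based layout argument are hypothetical:

```python
def use_smallseq_path(qkv_format: str, has_bias: bool,
                      max_seqlen_q: int, max_seqlen_kv: int) -> bool:
    """Mirror of the branch condition described in this PR: THD layout,
    no bias, Q length exactly 1, and KV length in the supported 2..16 window."""
    return (qkv_format == "thd"
            and not has_bias
            and max_seqlen_q == 1
            and 2 <= max_seqlen_kv <= 16)
```

Any case outside this window (e.g. kv length 17, or Q length 2) falls through to the standard CK path.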
- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
fused_attn_rocm/: declarations and implementation adapted from
varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.
- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
transformer_engine/common/CMakeLists.txt.
- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
== 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
kernel expects sequence-level batch).
- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
(workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
max_seqlen_kv; on real run call get_runtime_max_seqlen then
fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
get_runtime_max_seqlen, workspace size, and small-seq bwd.
- Reuse the softmax LSE auxiliary buffer for attention weights in the small-seq
path (written in forward, read in backward).
- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
buffer matches C++ attention-weights convention.
- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
(parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
C++.
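The description above says the kernels are instantiated for max_seqlen_kv in {2, 4, 6, 8, 12, 16}. One plausible dispatch scheme is to round the runtime KV length up to the smallest covering instantiation; whether the PR rounds up like this or requires an exact match is an assumption here, and the helper name is hypothetical:

```python
SUPPORTED_KV_LENS = (2, 4, 6, 8, 12, 16)  # kernel instantiations per the PR text

def pick_kv_kernel(runtime_max_seqlen_kv: int) -> int:
    """Select the smallest supported kernel size that covers the runtime
    KV length; reject lengths beyond the small-seq limit of 16."""
    for n in SUPPORTED_KV_LENS:
        if runtime_max_seqlen_kv <= n:
            return n
    raise ValueError(
        f"kv length {runtime_max_seqlen_kv} exceeds small-seq support (max 16)")
```

For example, a runtime KV length of 5 would map to the 6-wide instantiation, and 13 to the 16-wide one.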
Let's make this PR work for the jax extension first. Later we can support pytorch. One key difference between the jax and pytorch fused-attn dispatch is that pytorch can calculate, request, and allocate softmax_aux and workspace during runtime with the actual cu_seqlen_q/kv data. However, in the jax extension, the softmax_aux and workspace calculation is done in transformer_engine/jax/cpp_extensions/attention.py (lines 364 to 375 at b685686). General guideline:
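The guideline above implies the jax side must size buffers before the runtime cu_seqlens are known. A minimal sketch of one conservative approach, assuming the standard LSE layout is fp32 with last dim 1 and the small-seq attention-weights layout is 16-bit with last dim capped at 16 (the function name and parameters are illustrative, not TE's code):

```python
def trace_time_softmax_aux_bytes(batch: int, heads: int,
                                 max_seqlen_q: int, max_seqlen_kv: int) -> int:
    """Size the softmax aux buffer at trace time, without runtime cu_seqlens:
    take the max of the two candidate layouts so either path fits at runtime."""
    standard = batch * heads * max_seqlen_q * 1 * 4                  # fp32 LSE
    smallseq = batch * heads * max_seqlen_q * min(max_seqlen_kv, 16) * 2  # bf16/fp16
    return max(standard, smallseq)
```

With batch 2, 4 heads, s_q 1, s_kv 8, this sizes for the small-seq layout (128 bytes) rather than the 32-byte LSE.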
    NVTE_CHECK(workspace != nullptr, "small-seq bwd requires workspace.");
    ...
    float sqr_dk_scale = attn_scale;
    hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
Probably no need for this cast. cudaStream_t will be hipified correctly to hipStream_t.
…port to small-seq kernels
- tests/jax: CK small-seq tests use a fixture to set/restore NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing cases (2048-2-4, 2-4096-8192); when the env var is set, num_segments_per_seq = max_seqlen_q for THD, else 2.
- JAX attention.py: THD softmax shape/dtype uses the small-seq path only when env=1, else the original layout.
- JAX attention.cpp: added env guard.
- fused_attn_smallseq: use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
    const T* V_ptr = static_cast<const T*>(devPtrV);
    T* O_ptr = static_cast<T*>(devPtrO);
    T* attn_workspace = static_cast<T*>(attn_weights_buffer);
    const int* cu_kv = static_cast<const int*>(devPtrCuSeqlensKV);
Will it run into issues if we don't pass in cu_seqlen_q/cu_seqlen_q_padded? For example, if there are several empty segments for q/kv, but for all non-empty ones s_q always equals 1?
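The scenario raised above is easy to state concretely: a cumulative-seqlen array encodes empty segments as zero-length entries, so the max over non-empty Q segments can still be 1. A short illustration (the helper name is ours, not TE's):

```python
def seqlens_from_cu(cu_seqlens):
    """Recover per-sequence lengths from a cumulative-seqlen array.
    Empty segments show up as zero-length entries."""
    return [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]

# Two empty segments; every non-empty Q segment has length 1.
cu_q = [0, 1, 1, 2, 2, 3]
lens = seqlens_from_cu(cu_q)                  # [1, 0, 1, 0, 1]
nonempty_max = max(l for l in lens if l > 0)  # 1
```

The question is whether the varlen kernel indexes correctly over such zero-length entries when only cu_seqlens_kv is passed.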
    size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
        static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
    size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
Here we don't pass cu_seqlen_padded into the runtime max seqlen check. What if the max_seqlen without padding satisfies the ck kernel condition but with padding it does not? Can the ck kernel handle those corner cases?
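The padded-vs-unpadded corner case above can be made concrete with hypothetical numbers: the unpadded cu_seqlens imply a max length inside the small-seq window while the padded ones do not, so a check on the unpadded array alone would admit the case.

```python
def runtime_max_seqlen(cu_seqlens):
    """Max per-sequence length implied by a cumulative-seqlen array."""
    return max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))

# Hypothetical corner case: unpadded lengths fit the <=16 small-seq window,
# but padded lengths exceed it.
cu_kv        = [0, 16, 30]   # unpadded: max length 16 -> passes the guard
cu_kv_padded = [0, 20, 40]   # padded:   max length 20 -> would not
```

Checking only cu_kv here returns 16 and takes the small-seq path, even though the padded layout spans 20 elements per sequence.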
    )
    ck_smallseq_softmax_aux_size = (
        batch_size * attn_heads * q_max_seqlen
        * min(kv_max_seqlen, 16) * 2
Looking at the implementation, we only support kv_max_seqlen <= 16, right? So should this be checked via an assertion instead of enforced via min?
We cannot assert that, because we care about runtime_max_seqlen, and here we don't know the runtime_max_seqlen. For example, see tests/jax/test_fused_attn.py lines 1280 to 1281 and transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp lines 954 to 955 at 4537cce.
So this test would break for any case where the runtime_max_seqlen_kv is actually > 16?
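One defensive answer to the concern above is to re-check the runtime KV length after get_runtime_max_seqlen and divert to the standard path when the trace-time min(kv, 16) sizing would be exceeded. This is a sketch of that idea, not the PR's actual behavior; names and the capacity argument are illustrative:

```python
SMALLSEQ_KV_LIMIT = 16

def choose_path(runtime_max_seqlen_kv: int, aux_capacity_elems: int,
                batch: int, heads: int) -> str:
    """Divert to the standard CK path when the runtime KV length exceeds the
    small-seq limit or the trace-time-sized aux buffer would overflow."""
    needed = batch * heads * 1 * runtime_max_seqlen_kv  # s_q == 1
    if runtime_max_seqlen_kv <= SMALLSEQ_KV_LIMIT and needed <= aux_capacity_elems:
        return "smallseq"
    return "standard"
```

With a buffer sized for kv=16 (batch 2, 4 heads: 128 elements), a runtime kv of 8 stays on the small-seq path while 20 falls back.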
    if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
        softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
        softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
    else:
        softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16))
        softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
Suggested change (hoist the duplicated assignments and append only the last dim per branch):

    softmax_shape = (*batch_shape, attn_heads, q_max_seqlen)
    softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
    if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
        softmax_shape += (1,)
    else:
        softmax_shape += (min(kv_max_seqlen, 16),)
Force-pushed from b5c5fb7 to c6e0eae.
    )
    ck_smallseq_softmax_aux_size = (
        batch_size * attn_heads * q_max_seqlen
        * min(kv_max_seqlen, 16) * 2
So this test would break for any case where the runtime_max_seqlen_kv is actually > 16?
    run 1 test_fused_attn.py
    NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py  # Using FAv2 for forward and backward pass
    run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend'  # skip smallseq in normal flow
    XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend'  # CK smallseq with GPU graph disabled
For our new rocm7.2 image, disabling the xla cudagraph needs XLA_FLAGS="--xla_gpu_enable_command_buffer=".
    generates random segments. For KV: always generates random segments.
    """
    num_segments_per_seq = self.max_seqlen_q
    if self.max_seqlen_q == 1:
Will it run into problems if we call generate_random_segment_ids directly when self.max_seqlen_q == 1?
    self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
        self.batch_size,
        self.max_seqlen_kv,
    if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
Maybe put the small_seq flag into self.config to replace the env-var check?
    def ck_smallseq_env(monkeypatch):
        """Enable CK small-seq path and disable XLA GPU graphs for these tests."""
        if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""):
            pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'")
Not sure whether the new XLA_FLAG for cudagraph is due to the change in rocm or in jax. If it comes with the jax change, we can use a jax version check.
    void* workspace_next = workspace;
    ...
    const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
    if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Let's add another filter, s_q != s_kv, here.
    void* workspace_next = workspace;
    ...
    const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
    if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Same here, let's add the s_q != s_kv filter.
    if config.qkv_layout.is_thd():
        softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
        # THD only: check env; run small-seq logic only when enabled
        if os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") != "1":
Also filter it with q_max_seqlen != kv_max_seqlen.
    )  # 2 bytes for bf16/fp16
    if ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size:
        softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
        softmax_dtype = dtypes.canonicalize_dtype(q_dtype)
The softmax_dtype for the old ck flow is fp32, if I recall correctly?
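The size comparison behind this thread can be checked numerically, assuming (as the comment above recalls) that the standard flow stores an fp32 LSE with last dim 1 while the small-seq path stores 16-bit attention weights with last dim min(kv, 16); the helper name is ours:

```python
def aux_bytes(batch: int, heads: int, s_q: int, s_kv: int):
    """Byte sizes of the two candidate aux layouts: (standard LSE, small-seq)."""
    standard = batch * heads * s_q * 1 * 4               # fp32 LSE, last dim 1
    smallseq = batch * heads * s_q * min(s_kv, 16) * 2   # bf16/fp16 weights
    return standard, smallseq

std, small = aux_bytes(8, 4, 1, 2)   # at s_kv == 2 the two layouts tie
std2, small2 = aux_bytes(8, 4, 1, 8)  # beyond s_kv == 2 small-seq is larger
```

So the `ck_standard_softmax_aux_size >= ck_smallseq_softmax_aux_size` branch only wins for s_kv <= 2 under these assumptions.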
    auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
    ...
    const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
    if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Add the q_max_seqlen != kv_max_seqlen filter here as well.